{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TVM 自动量化校准 C++ 代码解读\n",
    "\n",
    "参考：`tvm/src/relay/quantize/calibrate.cc` 和 `tvm/python/tvm/relay/quantize/_calibrate.py`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "````{dropdown}\n",
    "```c++\n",
    "// KL divergence minimization code is adapted from MXNet.\n",
    "// The original one is in incubator-mxnet/src/operator/quantization/calibrate.cc\n",
    "static std::vector<float> SmoothDistribution(const std::vector<float>& p,\n",
    "                                             const float eps = 0.0001) {\n",
    "  std::vector<size_t> is_zeros(p.size());\n",
    "  std::vector<size_t> is_nonzeros(p.size());\n",
    "  {\n",
    "    auto it = p.begin();\n",
    "    std::generate(is_zeros.begin(), is_zeros.end(),\n",
    "                  [&it]() { return static_cast<size_t>(*(it++) == 0.f); });\n",
    "  }\n",
    "  {\n",
    "    auto it = p.begin();\n",
    "    std::generate(is_nonzeros.begin(), is_nonzeros.end(),\n",
    "                  [&it]() { return static_cast<size_t>(*(it++) != 0.f); });\n",
    "  }\n",
    "  size_t n_zeros = std::accumulate(is_zeros.begin(), is_zeros.end(), 0);\n",
    "  size_t n_nonzeros = p.size() - n_zeros;\n",
    "  if (!n_nonzeros) {\n",
    "    // The discrete probability distribution is malformed. All entries are 0.\n",
    "    return std::vector<float>();\n",
    "  }\n",
    "  float eps1 = eps * static_cast<float>(n_zeros) / static_cast<float>(n_nonzeros);\n",
    "  if (eps1 >= 1.0) return std::vector<float>();\n",
    "  auto ret = p;\n",
    "  for (size_t i = 0; i < p.size(); i++) {\n",
    "    ret[i] += eps * is_zeros[i] - eps1 * is_nonzeros[i];\n",
    "  }\n",
    "  return ret;\n",
    "}\n",
    "```\n",
    "````"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码实现了平滑离散概率分布函数(SmoothDistribution)，用于最小化 KL 散度。该函数接受浮点数向量 `p` 作为输入，并返回平滑后的浮点数向量。\n",
    "\n",
    "具体实现过程如下：\n",
    "1. 首先定义两个大小为 `p.size()` 的整数向量 `is_zeros` 和 `is_nonzeros`，分别用于记录 `p` 中每个元素是否为 `0` 或非 `0`。\n",
    "2. 使用 `std::generate` 函数生成 `is_zeros` 和 `is_nonzeros` 向量，其中 `is_zeros[i]` 表示 `p[i]` 是否为 `0`，`is_nonzeros[i]` 表示 `p[i]` 是否非 `0`。\n",
    "3. 计算 `p` 中 `0` 的个数 `n_zeros` 和非 0 的个数 `n_nonzeros`。\n",
    "4. 如果 `n_nonzeros` 为 `0`，说明离散概率分布格式有误，所有元素都为 `0`，直接返回空向量。\n",
    "5. 计算 `eps1`，即 `eps` 乘以 `n_zeros` 除以 `n_nonzeros`。如果 `eps1` 大于等于 `1.0`，也直接返回空向量。\n",
    "6. 定义新的向量 `ret`，将 `p`的值复制到 `ret` 中。\n",
    "7. 遍历 `p` 中的每个元素，根据 `is_zeros` 和 `is_nonzeros` 的值对 `ret` 进行更新，最终得到平滑后的离散概率分布。\n",
    "8. 返回平滑后的向量 `ret`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "````{dropdown}\n",
    "```c++\n",
    "static float ComputeEntropy(float* p, float* q, size_t size) {\n",
    "  float p_sum = std::accumulate(p, p + size, 0.f);\n",
    "  float q_sum = std::accumulate(q, q + size, 0.f);\n",
    "  float ret = 0;\n",
    "  for (size_t i = 0; i < size; i++) {\n",
    "    ICHECK(p[i] > 0 && q[i] > 0);\n",
    "    p[i] /= p_sum;\n",
    "    q[i] /= q_sum;\n",
    "    if (p[i] && q[i]) ret += p[i] * std::log(p[i] / q[i]);\n",
    "  }\n",
    "  return ret;\n",
    "}\n",
    "```\n",
    "````"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码是一个计算信息熵的函数，输入参数为两个浮点数数组 `p` 和 `q` 以及它们的大小 `size`。函数首先计算 `p` 和 `q` 的元素之和，然后遍历数组，对每个元素进行归一化处理，并计算信息熵。最后返回计算得到的信息熵值。\n",
    "\n",
    "````{dropdown}\n",
    "```c++\n",
    "float MinimizeKL(const std::vector<int>& hist, const std::vector<float>& hist_edges, int num_bins,\n",
    "                 int num_quantized_bins) {\n",
    "  const int zero_bin_idx = num_bins / 2;\n",
    "  const int num_half_quantized_bins = num_quantized_bins / 2;\n",
    "  std::vector<float> thresholds(num_bins / 2 + 1 - num_quantized_bins / 2, 0.f);\n",
    "  std::vector<float> divergence(thresholds.size(), 0.f);\n",
    "  std::vector<float> quantized_bins(num_quantized_bins, 0);\n",
    "  for (int i = num_quantized_bins / 2; i < zero_bin_idx + 1; ++i) {\n",
    "    const int p_bin_idx_start = zero_bin_idx - i;\n",
    "    const int p_bin_idx_stop = zero_bin_idx + i + 1;\n",
    "    thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop];\n",
    "\n",
    "    std::vector<int> sliced_nd_hist(p_bin_idx_stop - p_bin_idx_start);\n",
    "    std::vector<float> p(sliced_nd_hist.size());\n",
    "    p[0] = 0;\n",
    "    p.back() = 0;\n",
    "    for (int j = 0; j < num_bins; j++) {\n",
    "      if (j <= p_bin_idx_start) {\n",
    "        p[0] += hist[j];\n",
    "      } else if (j >= p_bin_idx_stop) {\n",
    "        p.back() += hist[j];\n",
    "      } else {\n",
    "        sliced_nd_hist[j - p_bin_idx_start] = hist[j];\n",
    "        p[j - p_bin_idx_start] = hist[j];\n",
    "      }\n",
    "    }\n",
    "    // calculate how many bins should be merged to generate quantized distribution q\n",
    "    const auto num_merged_bins = sliced_nd_hist.size() / num_quantized_bins;\n",
    "    for (int j = 0; j < num_quantized_bins; j++) {\n",
    "      const int start = j * num_merged_bins;\n",
    "      const int stop = (j + 1) * num_merged_bins;\n",
    "      quantized_bins[j] =\n",
    "          std::accumulate(sliced_nd_hist.begin() + start, sliced_nd_hist.begin() + stop, 0);\n",
    "    }\n",
    "    quantized_bins.back() += std::accumulate(\n",
    "        sliced_nd_hist.begin() + static_cast<int>(num_quantized_bins * num_merged_bins),\n",
    "        sliced_nd_hist.end(), 0);\n",
    "    // expand quantized_bins into p.size bins\n",
    "    std::vector<float> q(sliced_nd_hist.size(), 0);\n",
    "    for (int j = 0; j < num_quantized_bins; j++) {\n",
    "      const int start = j * num_merged_bins;\n",
    "      const int stop = (j == num_quantized_bins - 1) ? q.size() : ((j + 1) * num_merged_bins);\n",
    "      int norm = std::count_if(sliced_nd_hist.begin() + start, sliced_nd_hist.begin() + stop,\n",
    "                               [](size_t i) { return i != 0; });\n",
    "      if (norm) {\n",
    "        for (int k = start; k < stop; k++) {\n",
    "          if (p[k]) q[k] = quantized_bins[j] / norm;\n",
    "        }\n",
    "      }\n",
    "    }\n",
    "    p = SmoothDistribution(p);\n",
    "    q = SmoothDistribution(q);\n",
    "\n",
    "    if (!q.size()) {\n",
    "      divergence[i - num_half_quantized_bins] = std::numeric_limits<float>::infinity();\n",
    "    } else {\n",
    "      divergence[i - num_half_quantized_bins] = ComputeEntropy(p.data(), q.data(), p.size());\n",
    "    }\n",
    "  }\n",
    "  auto min_divergence_idx =\n",
    "      std::distance(divergence.begin(), std::min_element(divergence.begin(), divergence.end()));\n",
    "  return thresholds[min_divergence_idx];\n",
    "}\n",
    "```\n",
    "````"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码是一个最小化 KL 散度的函数，输入参数为一个整数向量 hist、一个浮点数向量 hist_edges、两个整数 num_bins 和 num_quantized_bins。函数首先定义了一些变量，包括零分箱索引zero_bin_idx、半量化分箱数num_half_quantized_bins、阈值向量thresholds、发散度向量divergence和量化分箱向量quantized_bins。然后，函数遍历hist_edges中的元素，计算每个元素对应的p和q分布，并计算它们之间的KL散度。最后，函数返回具有最小 KL 散度的阈值。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "````{dropdown}\n",
    "```c++\n",
    "class StatsCollector : private ExprMutator {\n",
    " public:\n",
    "  StatsCollector() : simulated_quantize_op_(Op::Get(\"relay.op.annotation.simulated_quantize\")) {}\n",
    "\n",
    "  Expr Collect(const Expr& expr) {\n",
    "    auto new_e = this->Mutate(expr);\n",
    "    const FunctionNode* func = new_e.as<FunctionNode>();\n",
    "    ICHECK(func) << \"Input shoule be Function\";\n",
    "    Expr new_body = Tuple(std::move(profile_data_));\n",
    "    Function ret_func = WithFields(GetRef<Function>(func), FreeVars(new_body), new_body);\n",
    "\n",
    "    // We are changing the function's ret_type to an empty type. Unfortunately, Optional<Type>() is\n",
    "    // indistinguishable from NullValue<Type>(), so we can't express \"update to nullptr\" in\n",
    "    // WithFields.\n",
    "    ret_func.CopyOnWrite()->ret_type = NullValue<Type>();\n",
    "    return std::move(ret_func);\n",
    "  }\n",
    "\n",
    " private:\n",
    "  Array<Expr> profile_data_;\n",
    "  const Op& simulated_quantize_op_;\n",
    "\n",
    "  Expr VisitExpr_(const CallNode* call) {\n",
    "    Expr new_e = ExprMutator::VisitExpr_(call);\n",
    "    const CallNode* new_call = new_e.as<CallNode>();\n",
    "    ICHECK(new_call);\n",
    "    if (new_call->op == simulated_quantize_op_) {\n",
    "      auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();\n",
    "      // rewrite the annotation\n",
    "      auto new_attrs = make_object<SimulatedQuantizeAttrs>();\n",
    "      const Expr& quantize_input = new_call->args[0];                  // expression being quantized\n",
    "      auto placeholder = MakeConstantScalar(DataType::Float(32), 0.);  // unused argument\n",
    "      Array<Expr> new_args{quantize_input, placeholder, placeholder, placeholder};\n",
    "      new_attrs->kind = QAnnotateKind::kQIdentity;\n",
    "      new_attrs->sign = attrs->sign;\n",
    "      new_attrs->rounding = attrs->rounding;\n",
    "      Expr identity_quantize = Call(new_call->op, new_args, Attrs{new_attrs}, {});\n",
    "\n",
    "      // add non-const expressions to profile data\n",
    "      if (attrs->kind != QAnnotateKind::kQWeight) {\n",
    "        ICHECK(!quantize_input.as<ConstantNode>());\n",
    "        profile_data_.push_back(identity_quantize);\n",
    "      }\n",
    "      return identity_quantize;\n",
    "    } else {\n",
    "      return new_e;\n",
    "    }\n",
    "  }\n",
    "};\n",
    "```\n",
    "````"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码定义了 StatsCollector 类，它继承自 ExprMutator 类。该类的主要作用是收集表达式中的量化信息，并将这些信息存储在 `profile_data_` 数组中。\n",
    "\n",
    "在Collect函数中，首先调用Mutate函数对输入的表达式进行遍历和修改，然后将其转换为FunctionNode类型，并检查其是否为空。接着，将 `profile_data_` 数组转换为 Tuple 类型，并将其作为新的函数体。最后，将新函数的返回类型设置为 `NullValue<Type>()`，表示返回类型为空。\n",
    "\n",
    "在 ``VisitExpr_`` 函数中，首先调用 ``ExprMutator::VisitExpr_`` 函数对 ``CallNode`` 类型的节点进行处理。如果该节点算子是 ``simulated_quantize_op_``，则获取该节点的属性，并创建一个新的 ``SimulatedQuantizeAttrs`` 对象。接着，将该节点的第一个参数作为量化表达式，创建一个占位符常量，并将它们与新属性一起传递给 ``Call`` 函数，生成 ``identity_quantize`` 节点。如果该节点的属性 ``kind`` 不等于 ``kQWeight``，则将非 ``const`` 表达式添加到 ``profile_data_`` 数组中。最后，返回 ``identity_quantize`` 节点。否则，直接返回 ``new_e`` 节点。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "````{dropdown}\n",
    "```c++\n",
    "/*\n",
    " * \\brief Given an annotated graph, create a profile graph to collect profile data from the\n",
    " * calibration dataset.\n",
    " *\n",
    " * This pass collects simulated_quantize op into a tuple. Simulated_quantize ops are rewritten to\n",
    " * identity mode. The tuple is the output of the profile graph. Both input and output of this pass\n",
    " * are relay::Function.\n",
    " *\n",
    " * \\param expr The simulation graph after annotation.\n",
    " * \\return The profile graph.\n",
    " */\n",
    "Expr CreateStatsCollector(const Expr& expr) { return StatsCollector().Collect(expr); }\n",
    "\n",
    "TVM_REGISTER_GLOBAL(\"relay._quantize.CreateStatsCollector\").set_body_typed(CreateStatsCollector);\n",
    "\n",
    "TVM_REGISTER_GLOBAL(\"relay._quantize.FindScaleByKLMinimization\")\n",
    "    .set_body([](TVMArgs args, TVMRetValue* ret) {\n",
    "      int* hist_ptr = static_cast<int*>(static_cast<void*>(args[0]));\n",
    "      float* hist_edges_ptr = static_cast<float*>(static_cast<void*>(args[1]));\n",
    "      int num_bins = args[2];\n",
    "      int num_quantized_bins = args[3];\n",
    "      std::vector<int> hist(hist_ptr, hist_ptr + num_bins);\n",
    "      std::vector<float> hist_edges(hist_edges_ptr, hist_edges_ptr + num_bins + 1);\n",
    "      ret[0] = MinimizeKL(hist, hist_edges, num_bins, num_quantized_bins);\n",
    "    });\n",
    "\n",
    "```\n",
    "````"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码定义了一个名为 `StatsCollector` 的类，它继承自 `ExprMutator` 类。该类的主要作用是收集表达式中的量化信息，并将这些信息存储在 `profile_data_` `数组中。\n",
    "\n",
    "在Collect函数中，首先调用Mutate函数对输入的表达式进行遍历和修改，然后将其转换为FunctionNode类型，并检查其是否为空。接着，将profile_data_数组转换为Tuple类型，并将其作为新的函数体。最后，将新函数的返回类型设置为 `NullValue<Type>()`，表示返回类型为空。\n",
    "\n",
    "在 `VisitExpr_` 函数中，首先调用 ``ExprMutator::VisitExpr_`` 函数对CallNode类型的节点进行处理。如果该节点的算子是 `simulated_quantize_op_`，则获取该节点的属性，并创建一个新的 SimulatedQuantizeAttrs对象。接着，将该节点的第一个参数作为量化表达式，创建一个占位符常量，并将它们与新属性一起传递给Call函数，生成一个identity_quantize节点。如果该节点的属性kind不等于kQWeight，则将非const表达式添加到profile_data_数组中。最后，返回identity_quantize节点。否则，直接返回new_e节点。\n",
    "\n",
    "此外，还定义了两个全局变量：`CreateStatsCollector` 和 ``FindScaleByKLMinimization``。`CreateStatsCollector` 用于创建统计收集器，而 `FindScaleByKLMinimization` 用于通过KL最小化方法查找scale。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mSignature:\u001b[0m\n",
      "\u001b[0m_find_scale_by_kl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0marr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mquantized_dtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'int8'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mnum_bins\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m8001\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mnum_quantized_bins\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m255\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mDocstring:\u001b[0m\n",
      "Given a tensor, find the optimal threshold for quantizing it.\n",
      "The reference distribution is `q`, and the candidate distribution is `p`.\n",
      "`q` is a truncated version of the original distribution.\n",
      "\n",
      "Ref:\n",
      "http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf\n",
      "\u001b[0;31mFile:\u001b[0m      /media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/tvm/relay/quantize/kl_divergence.py\n",
      "\u001b[0;31mType:\u001b[0m      compiled_function"
     ]
    }
   ],
   "source": [
    "from tvm.relay.quantize.kl_divergence import _find_scale_by_kl\n",
    "\n",
    "_find_scale_by_kl??"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "示例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "频数： [  0   0   0   2   4  18  44  93 140 177 195 156 109  41  11   5   3   2\n",
      "   0   0]\n",
      "分箱边界： [-5.  -4.5 -4.  -3.5 -3.  -2.5 -2.  -1.5 -1.  -0.5  0.   0.5  1.   1.5\n",
      "  2.   2.5  3.   3.5  4.   4.5  5. ]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# 生成随机数据\n",
    "data = np.random.randn(1000)\n",
    "\n",
    "# 定义分箱边界\n",
    "bin_edges = np.linspace(-5, 5, 21)\n",
    "\n",
    "# 计算直方图\n",
    "hist, bin_edges = np.histogram(data, bins=bin_edges)\n",
    "\n",
    "print(\"频数：\", hist)\n",
    "print(\"分箱边界：\", bin_edges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "aix",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
